# --------------------------------------------------------
# modified from Hora

import torch
import numpy as np
from isaacgym import gymtorch
from isaacgym.torch_utils import torch_rand_float, quat_from_angle_axis, quat_mul, tensor_clamp, to_torch, quat_conjugate, quat_apply
from hora.tasks.leap_hand_hora import LeapHandHora
from isaacgym import gymapi
import transforms3d
import os


class LeapHandGrasp(LeapHandHora):
    """Grasp-generation variant of ``LeapHandHora``.

    Each reset places the object and sets the hand joints to a noisy initial
    pose. The pose comes either from a file of pre-optimized grasps or from a
    fixed canonical pose. ``compute_reward`` resets envs whose grasp fails its
    checks. At reset time, envs that survived to ``max_episode_length`` are
    appended to ``saved_grasping_states``. That cache is written to ``.npy``
    once 1k samples exist and again (followed by process exit) at 50k samples.
    """

    # Wrist rotation (multiples of pi about x, y, z) per ``handGraspFacingDir``;
    # unknown values fall back to ``_DEFAULT_FACING_DIR_RPY``.
    _HAND_FACING_DIR_RPY = {
        'palm_down': (0.0, 0.0, 0.0),
        'palm_up': (0.0, -1.0, 0.0),
        'base_up': (0.5, 1.5, 0.0),
        'base_down': (1.5, 1.5, 0.0),
        'thumb_up': (0.0, 1.5, 0.0),
    }
    _DEFAULT_FACING_DIR_RPY = (1.0, 1.5, 0.0)

    def __init__(self, config, sim_device, graphics_device_id, headless):
        super().__init__(config, sim_device=sim_device, graphics_device_id=graphics_device_id, headless=headless)
        # One row per cached grasp: hand DoF positions followed by the object's
        # first 7 root-state entries (position 0:3, quaternion 3:7). The 23
        # columns assume 16 hand DoFs.
        self.saved_grasping_states = torch.zeros((0, 23), dtype=torch.float, device=self.device)

        self.change_gravity_dir = config['env'].get('changeGravityDir')

        # Default initial joint pose, used when no pre-optimized grasps are loaded.
        # The non-leap ordering swaps the first two entries of every 4-joint group.
        if self.hand_type == 'leap':
            self.canonical_pose = [
                1.244, 0.082, 0.265, 0.298,
                1.163, 1.104, 0.953, -0.138,
                1.096, 0.005, 0.080, 0.150,
                1.337, 0.029, 0.285, 0.317,
            ]
        else:
            self.canonical_pose = [
                0.082, 1.244, 0.265, 0.298, 1.104, 1.163, 0.953, -0.138,
                0.005, 1.096, 0.080, 0.150, 0.029, 1.337, 0.285, 0.317,
            ]

        self.x_unit_tensor = to_torch([1, 0, 0], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.y_unit_tensor = to_torch([0, 1, 0], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.z_unit_tensor = to_torch([0, 0, 1], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))

        self.customize_grasp_pose = config['env'].get('customizeGraspPose', False)
        self.use_preoptimized_grasp_pose = config['env'].get('usePreoptimizedGraspPose', False)
        self.preoptimized_grasp_pose_fn = config['env'].get('preoptimizedGraspPoseFn', '')
        self.use_original_leap_pose = config['env'].get('useOriginalLeapPose', False)

        self.hand_grasp_facing_dir = config['env'].get('handGraspFacingDir', 'palm_down')

        if self.use_preoptimized_grasp_pose:
            # reset_idx samples from tot_hand_qpos, so fail early instead of at the first reset.
            if len(self.preoptimized_grasp_pose_fn) == 0:
                raise ValueError('usePreoptimizedGraspPose is set but preoptimizedGraspPoseFn is empty')
            self._load_preoptimized_grasps()

        if self.change_gravity_dir:
            # Number of compute_reward calls between gravity-direction switches.
            # (Attribute name typo kept for compatibility.)
            self.chagne_gravity_episode = 50  # 200
            # Candidate gravity directions: +x, +y, +z, -x, -y, -z.
            self.unit_tensors = torch.tensor(
                [
                    [1, 0, 0], [0, 1, 0], [0, 0, 1],
                    [-1, 0, 0], [0, -1, 0], [0, 0, -1]
                ], dtype=torch.float32, device=self.device
            )
            self.global_step = 0

            # Root states recorded at reset; compute_reward measures object drift against them.
            self.ori_root_state_tensor = None
            self.obj_pos_deviation_threshold = 0.04

        self.saved_1k_samples = False

        self.grasp_cache_sv_root = "."

    def _load_preoptimized_grasps(self):
        """Load grasps from ``self.preoptimized_grasp_pose_fn``.

        This sets:

        * ``tot_hand_qpos``: (num_samples, nn_dofs) joint targets.
        * ``tot_obj_quat``: (num_samples, 4) object quaternions.
        * ``tot_obj_trans``: (num_samples, 3) object positions.

        Each entry's ``'qpos'`` dict is read by joint/wrist name. The object
        pose is obtained by applying to the origin the rigid transform that
        maps the grasp's wrist pose onto ``self.hand_pose``. This presumably
        means the grasp data is expressed in the object frame — confirm
        against the grasp generator.
        """
        print(f"Loading preoptimized grasp pose from {self.preoptimized_grasp_pose_fn}")
        # NOTE(review): allow_pickle=True can execute code from the file; only load trusted grasp files.
        self.preoptimized_grasp_pose = np.load(self.preoptimized_grasp_pose_fn, allow_pickle=True)

        if self.hand_type == 'leap' and self.use_original_leap_pose:
            self.hand_joint_names = ['1', '0', '2', '3', '12', '13', '14', '15', '5', '4', '6', '7', '9', '8', '10', '11']
            rot_names = ['WRJ0rx', 'WRJ0ry', 'WRJ0rz']
            translation_names = ['WRJ0x', 'WRJ0y', 'WRJ0z']
        else:
            self.hand_joint_names = ['joint_0.0', 'joint_1.0', 'joint_2.0', 'joint_3.0', 'joint_12.0', 'joint_13.0', 'joint_14.0', 'joint_15.0', 'joint_4.0', 'joint_5.0', 'joint_6.0', 'joint_7.0', 'joint_8.0', 'joint_9.0', 'joint_10.0', 'joint_11.0']
            rot_names = ['WRJRx', 'WRJRy', 'WRJRz']
            translation_names = ['WRJTx', 'WRJTy', 'WRJTz']

        # The sim hand pose is the same for every sample, so read it once.
        sim_hand_rot = self.hand_pose.r
        sim_hand_rot = torch.tensor(
            [sim_hand_rot.x, sim_hand_rot.y, sim_hand_rot.z, sim_hand_rot.w], dtype=torch.float32
        ).to(self.device).unsqueeze(0)  # (1, 4), (x, y, z, w)
        sim_hand_trans = self.hand_pose.p
        sim_hand_trans = torch.tensor(
            [sim_hand_trans.x, sim_hand_trans.y, sim_hand_trans.z], dtype=torch.float32
        ).to(self.device)  # (3,)

        tot_hand_qpos, tot_obj_quat, tot_obj_trans = [], [], []
        for cur_grasp in self.preoptimized_grasp_pose:
            grasp_qpos = cur_grasp['qpos']

            # transforms3d quaternions are (w, x, y, z); reorder to (x, y, z, w).
            hand_rot = np.array(transforms3d.euler.euler2quat(*[grasp_qpos[name] for name in rot_names]))
            hand_rot_th = torch.from_numpy(hand_rot[[1, 2, 3, 0]]).float().to(self.device)  # (4,)

            hand_trans = np.array([grasp_qpos[name] for name in translation_names], dtype=np.float32)
            hand_trans_th = torch.from_numpy(hand_trans).float().to(self.device)  # (3,)

            # delta_rot = sim_hand_rot * grasp_hand_rot^-1 rotates the grasp wrist onto the sim wrist.
            delta_hand_rot = quat_mul(sim_hand_rot, quat_conjugate(hand_rot_th.unsqueeze(0))).squeeze(0)
            tot_obj_quat.append(delta_hand_rot)
            # Translation part of the same rigid transform (the image of the origin).
            tot_obj_trans.append(sim_hand_trans - quat_apply(delta_hand_rot, hand_trans_th))

            tot_hand_qpos.append(np.array([grasp_qpos[name] for name in self.hand_joint_names], dtype=np.float32))

        self.tot_hand_qpos = torch.from_numpy(np.stack(tot_hand_qpos, axis=0)).float().to(self.device)  # (num_samples, nn_dofs)
        self.tot_obj_quat = torch.stack(tot_obj_quat, dim=0)  # (num_samples, 4)
        self.tot_obj_trans = torch.stack(tot_obj_trans, dim=0)  # (num_samples, 3)

        if self.hand_type == 'leap' and not self.use_original_leap_pose:
            # Adapt Allegro-ordered qpos to LEAP: swap the first two joints of every 4-joint group.
            allegro_to_leap = [group * 4 + j for group in range(4) for j in (1, 0, 2, 3)]
            allegro_to_leap = torch.tensor(allegro_to_leap, dtype=torch.long, device=self.device)
            self.tot_hand_qpos = self.tot_hand_qpos[:, allegro_to_leap]  # (num_samples, nn_dofs)

    def _grasp_cache_path(self, size_tag):
        """Return the ``.npy`` path for the grasp cache tagged ``size_tag`` (e.g. '1k', '50k')."""
        scale_tag = str(self.base_obj_scale).replace(".", "")
        if len(self.specified_obj_idx) > 0:
            obj_tag = "_".join(str(cur_idx) for cur_idx in self.specified_obj_idx)
            name = f'cache/{self.grasp_cache_name}_{obj_tag}_grasp_{size_tag}_s{scale_tag}.npy'
        else:
            name = f'cache/{self.grasp_cache_name}_grasp_{size_tag}_s{scale_tag}.npy'
        return os.path.join(self.grasp_cache_sv_root, name)

    def _save_grasp_cache(self, path, num_samples):
        """Write the first ``num_samples`` cached grasps to ``path``, creating its directory if needed."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        np.save(path, self.saved_grasping_states[:num_samples].cpu().numpy())
        print(f"save grasp cache to {path}")

    def _object_masses(self):
        """Return a (num_envs,) tensor of each env's 'object' actor mass (its first rigid body)."""
        return to_torch(
            [self.gym.get_actor_rigid_body_properties(env, self.gym.find_actor_handle(env, 'object'))[0].mass
             for env in self.envs], device=self.device)

    def _gravity_forces(self, obj_mass):
        """Return (num_envs, 3) forces of magnitude mass * 9.8 along the current gravity direction.

        The direction cycles through ``unit_tensors`` and advances every
        ``chagne_gravity_episode`` increments of ``global_step``.
        """
        gravity_scale = 9.8
        cur_gravity_dir_idx = (self.global_step // self.chagne_gravity_episode) % self.unit_tensors.size(0)
        cur_gravity_tsr = self.unit_tensors[cur_gravity_dir_idx]
        return cur_gravity_tsr.unsqueeze(0) * obj_mass.unsqueeze(1) * gravity_scale

    def reset_idx(self, env_ids):
        """Reset the envs in ``env_ids``.

        Steps, in order:

        1. Optionally re-randomize object mass and PD gains.
        2. Append the states of envs that reached ``max_episode_length`` to the
           grasp cache, saving it at 1k samples and at 50k samples (which also
           exits the process).
        3. Re-place the object and set the hand joints to a noisy initial grasp.
        """
        # ---- object mass (optionally randomized) -> privileged info ----
        lower, upper = self.randomize_mass_lower, self.randomize_mass_upper
        for env_id in env_ids:
            env = self.envs[env_id]
            handle = self.gym.find_actor_handle(env, 'object')
            prop = self.gym.get_actor_rigid_body_properties(env, handle)
            if self.randomize_mass:
                for p in prop:
                    p.mass = np.random.uniform(lower, upper)
                self.gym.set_actor_rigid_body_properties(env, handle, prop)
            self._update_priv_buf(env_id=env_id, name='obj_mass', value=prop[0].mass, lower=0, upper=0.2)

        if self.randomize_pd_gains:
            self.p_gain[env_ids] = torch_rand_float(
                self.randomize_p_gain_lower, self.randomize_p_gain_upper, (len(env_ids), self.num_actions),
                device=self.device).squeeze(1)
            self.d_gain[env_ids] = torch_rand_float(
                self.randomize_d_gain_lower, self.randomize_d_gain_upper, (len(env_ids), self.num_actions),
                device=self.device).squeeze(1)

        # Columns 3-4 drive object rotation; columns 5: are per-joint noise.
        rand_floats = torch_rand_float(-1.0, 1.0, (len(env_ids), self.num_allegro_hand_dofs * 2 + 5), device=self.device)

        self.rb_forces[env_ids, :, :] = 0.0

        # ---- cache grasps that survived the whole episode ----
        success = self.progress_buf[env_ids] == self.max_episode_length
        all_states = torch.cat([
            self.allegro_hand_dof_pos, self.root_state_tensor[self.object_indices, :7]
        ], dim=1)
        self.saved_grasping_states = torch.cat([self.saved_grasping_states, all_states[env_ids][success]])
        print('current cache size:', self.saved_grasping_states.shape[0])

        num_saved = len(self.saved_grasping_states)
        if num_saved >= 50000:
            self._save_grasp_cache(self._grasp_cache_path('50k'), 50000)
            raise SystemExit
        elif num_saved >= 1000 and not self.saved_1k_samples:
            self._save_grasp_cache(self._grasp_cache_path('1k'), 1000)
            self.saved_1k_samples = True

        # ---- object pose ----
        obj_ids = self.object_indices[env_ids]
        self.root_state_tensor[obj_ids] = self.object_init_state[env_ids].clone()

        if self.use_preoptimized_grasp_pose:
            rand_sampled_qpos_idxes = torch.randint(0, self.tot_hand_qpos.shape[0], (len(env_ids),), device=self.device)
            new_object_rot = self.tot_obj_quat[rand_sampled_qpos_idxes]

            # Advanced indexing returns a copy, so the in-place offsets below do not touch tot_obj_trans.
            rand_sampled_obj_pos = self.tot_obj_trans[rand_sampled_qpos_idxes]
            if self.hand_type == 'leap' and not self.use_original_leap_pose:
                # NOTE(review): hand-tuned offset for Allegro grasps retargeted to LEAP.
                rand_sampled_obj_pos[..., 1] -= 0.07
                rand_sampled_obj_pos[..., 0] -= 0.02
            self.root_state_tensor[obj_ids, 0:3] = rand_sampled_obj_pos
        elif self.customize_grasp_pose:
            rand_floats[:, 4] = 0.5
            new_object_rot = randomize_rotation(rand_floats[:, 3], rand_floats[:, 4], self.z_unit_tensor[env_ids], self.y_unit_tensor[env_ids])
        else:
            # Identity quaternion (x, y, z, w) = (0, 0, 0, 1).
            new_object_rot = torch.zeros((len(env_ids), 4), dtype=torch.float, device=self.device)
            new_object_rot[:, -1] = 1

        # ---- optional wrist re-orientation; the object is rotated about the hand root with it ----
        if self.omni_wrist_ornt:
            # Every column is overwritten below; the draw is kept so the torch RNG stream is unchanged.
            rnd_float = torch_rand_float(-1.0, 1.0, (len(env_ids), 3), device=self.device)
            rpy = self._HAND_FACING_DIR_RPY.get(self.hand_grasp_facing_dir, self._DEFAULT_FACING_DIR_RPY)
            rnd_float[:] = torch.tensor(rpy, dtype=rnd_float.dtype, device=self.device)
            new_rnd_rot = randomize_rotation_rpy(rnd_float[:, 0], rnd_float[:, 1], rnd_float[:, 2], self.x_unit_tensor[env_ids], self.y_unit_tensor[env_ids], self.z_unit_tensor[env_ids])

            q_h = self.hand_pose_tsr[3:].unsqueeze(0).repeat(len(env_ids), 1).contiguous()
            self.root_state_tensor[self.hand_indices[env_ids], 3:7] = quat_mul(new_rnd_rot, q_h)
            hand_indices = self.hand_indices[env_ids].to(torch.int32)
            self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.root_state_tensor), gymtorch.unwrap_tensor(hand_indices), len(env_ids))

            new_object_rot = quat_mul(new_rnd_rot, new_object_rot)
            t_o = self.root_state_tensor[obj_ids, 0:3].clone()
            t_h = self.hand_pose_tsr[:3].unsqueeze(0).repeat(len(env_ids), 1).contiguous()  # (len(env_ids), 3)
            self.root_state_tensor[obj_ids, 0:3] = t_h + quat_apply(new_rnd_rot, t_o - t_h)

        self.root_state_tensor[obj_ids, 3:7] = new_object_rot
        self.root_state_tensor[obj_ids, 7:13] = torch.zeros_like(self.root_state_tensor[obj_ids, 7:13])

        object_indices = torch.unique(obj_ids).to(torch.int32)
        self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.root_state_tensor),
                                                     gymtorch.unwrap_tensor(object_indices), len(object_indices))

        # ---- hand joints: sampled/canonical grasp plus uniform noise ----
        qpos_noise = rand_floats[:, 5:5 + self.num_allegro_hand_dofs]
        if self.use_preoptimized_grasp_pose:
            pos = self.tot_hand_qpos[rand_sampled_qpos_idxes] + 0.05 * qpos_noise
        else:
            pos = to_torch(self.canonical_pose, device=self.device)[None].repeat(len(env_ids), 1) + 0.25 * qpos_noise
        pos = tensor_clamp(pos, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)

        self.allegro_hand_dof_pos[env_ids, :] = pos
        self.allegro_hand_dof_vel[env_ids, :] = 0
        self.prev_targets[env_ids, :self.num_allegro_hand_dofs] = pos
        self.cur_targets[env_ids, :self.num_allegro_hand_dofs] = pos

        hand_indices = self.hand_indices[env_ids].to(torch.int32)
        if not self.torque_control:
            self.gym.set_dof_position_target_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.prev_targets),
                                                            gymtorch.unwrap_tensor(hand_indices), len(env_ids))
        self.gym.set_dof_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.dof_state),
                                              gymtorch.unwrap_tensor(hand_indices), len(env_ids))

        self.progress_buf[env_ids] = 0
        self.obs_buf[env_ids] = 0
        self.rb_forces[env_ids] = 0
        self.priv_info_buf[env_ids, 0:3] = 0
        self.proprio_hist_buf[env_ids] = 0

        self.at_reset_buf[env_ids] = 1

        if self.change_gravity_dir:
            if self.ori_root_state_tensor is None:
                self.ori_root_state_tensor = self.root_state_tensor.clone()
            else:
                # root_state_tensor rows are actor indices, not env ids: refresh the
                # hand and object rows of the reset envs (compute_reward reads object rows).
                for actor_ids in (self.hand_indices[env_ids], self.object_indices[env_ids]):
                    self.ori_root_state_tensor[actor_ids] = self.root_state_tensor[actor_ids]

    def pre_physics_step(self, actions):
        """Apply actions as clamped joint-target deltas and push external forces to the sim.

        Random object forces (when ``force_scale > 0``) are stored in
        ``rb_forces`` and decay over time. The changing-gravity force is
        recomputed each step and added only to the tensor handed to the sim,
        so it never accumulates across steps.
        """
        self.actions = actions.clone().to(self.device)
        targets = self.prev_targets + 1 / 24 * self.actions
        self.cur_targets[:] = tensor_clamp(targets, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)
        self.prev_targets[:] = self.cur_targets.clone()
        self.object_rot_prev[:] = self.object_rot
        self.object_pos_prev[:] = self.object_pos

        if self.force_scale > 0.0:
            self.rb_forces *= torch.pow(self.force_decay, self.dt / self.force_decay_interval)
            # apply new forces
            obj_mass = self._object_masses()
            prob = self.random_force_prob_scalar
            force_indices = (torch.less(torch.rand(self.num_envs, device=self.device), prob)).nonzero()
            self.rb_forces[force_indices, self.object_rb_handles, :] = torch.randn(
                self.rb_forces[force_indices, self.object_rb_handles, :].shape,
                device=self.device) * obj_mass[force_indices, None] * self.force_scale

            applied_forces = self.rb_forces
            if self.change_gravity_dir:
                # Add gravity to a copy: writing it into the persistent, decaying
                # rb_forces would accumulate it every step.
                applied_forces = self.rb_forces.clone()
                applied_forces[:, self.object_rb_handles, :] += self._gravity_forces(obj_mass).unsqueeze(1)
            self.gym.apply_rigid_body_force_tensors(self.sim, gymtorch.unwrap_tensor(applied_forces), None, gymapi.ENV_SPACE)

        elif self.change_gravity_dir:
            obj_mass = self._object_masses()
            self.rb_forces[:, self.object_rb_handles, :] = self._gravity_forces(obj_mass).unsqueeze(1)
            self.gym.apply_rigid_body_force_tensors(self.sim, gymtorch.unwrap_tensor(self.rb_forces), None, gymapi.ENV_SPACE)

    def compute_reward(self, actions):
        """Flag envs for reset when the current grasp is not acceptable.

        Despite its name, this method computes no reward; it only writes
        ``reset_buf``. An env keeps running while all of the following hold:

        1. All four fingertip bodies (4, 8, 12, 16) are within 0.1 of the
           object body.
        2. At least two fingertips touch the object.
        3. The object's z stays above ``reset_z_threshold``. With
           ``change_gravity_dir``, this becomes: the object stays within
           ``obj_pos_deviation_threshold`` of its reset position.

        Envs reaching ``max_episode_length`` are also reset; ``reset_idx``
        counts those as successful grasps. Requires the CPU pipeline, because
        contacts are read per env through the gym API.
        """
        def list_intersect(li, hash_num):
            # Count fingertip/object contact pairs. 17 is the object body index;
            # 4, 8, 12, 16 are fingertip indices. Pairs are hashed as a * hash_num + b.
            obj_id = 17
            query_list = [obj_id * hash_num + 4, obj_id * hash_num + 8, obj_id * hash_num + 12, obj_id * hash_num + 16]
            return len(np.intersect1d(query_list, li))

        assert self.device == 'cpu'
        contacts = [self.gym.get_env_rigid_contacts(env) for env in self.envs]
        # c[2], c[3] are presumably the two body indices of each contact record — verify against gym docs.
        contact_list = [list_intersect(np.unique([c[2] * 10000 + c[3] for c in contact]), 10000) for contact in contacts]
        contact_condition = to_torch(contact_list, device=self.device)

        obj_pos = self.rigid_body_states[:, [-1], :3]
        finger_pos = self.rigid_body_states[:, [4, 8, 12, 16], :3]
        # 1) all fingertips are near the object
        cond1 = (torch.sqrt(((obj_pos - finger_pos) ** 2).sum(-1)) < 0.1).all(-1)
        # 2) at least two fingers are in contact with the object
        cond2 = contact_condition >= 2
        # 3) the object does not fall / drift (0.645 for internal allegro, 0.625 for public allegro)
        if self.change_gravity_dir:
            ori_obj_pos = self.ori_root_state_tensor[self.object_indices, :3]
            cur_obj_pos = obj_pos[:, -1, :]
            obj_pos_deviation = torch.sqrt(torch.sum((ori_obj_pos - cur_obj_pos) ** 2, dim=-1))
            cond3 = obj_pos_deviation < self.obj_pos_deviation_threshold
        else:
            cond3 = torch.greater(obj_pos[:, -1, -1], self.reset_z_threshold)

        cond = cond1.float() * cond2.float() * cond3.float()
        # reset if any of the above conditions does not hold
        self.reset_buf[cond < 1] = 1
        self.reset_buf[self.progress_buf >= self.max_episode_length] = 1

        if self.change_gravity_dir:
            self.global_step += 1


@torch.jit.script
def randomize_rotation_rpy(rand0, rand1, rand2, x_unit_tensor, y_unit_tensor, z_unit_tensor):
    """Build a batch of quaternions from three axis-angle rotations.

    Each ``randK`` is scaled by pi to give an angle about the matching axis
    tensor. The three quaternions are composed as
    ``quat_mul(quat_mul(q_x, q_y), q_z)``.
    """
    q_x = quat_from_angle_axis(rand0 * np.pi, x_unit_tensor)
    q_y = quat_from_angle_axis(rand1 * np.pi, y_unit_tensor)
    q_z = quat_from_angle_axis(rand2 * np.pi, z_unit_tensor)
    q_xy = quat_mul(q_x, q_y)
    return quat_mul(q_xy, q_z)


@torch.jit.script
def randomize_rotation(rand0, rand1, x_unit_tensor, y_unit_tensor):
    """Compose two axis-angle rotations into a batch of quaternions.

    The angles are ``rand0 * pi`` about ``x_unit_tensor`` and ``rand1 * pi``
    about ``y_unit_tensor``, and the result is ``quat_mul(q_first, q_second)``.
    Any axis tensors may be passed; the parameter names only reflect the
    default use.
    """
    q_first = quat_from_angle_axis(rand0 * np.pi, x_unit_tensor)
    q_second = quat_from_angle_axis(rand1 * np.pi, y_unit_tensor)
    return quat_mul(q_first, q_second)
